"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.

A PyTorch implementation of ResNetV2 adapted from the Google Big Transfer (BiT) source code
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
been included here as pretrained models from their original .NPZ checkpoints.

Additionally, supports a non-pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
https://github.com/google-research/vision_transformer

Thanks to the Google team for the above two repositories and associated papers:
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237

Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
"""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict  # pylint: disable=g-importing-member

import torch
import torch.nn as nn
from functools import partial

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
from .layers import (
    GroupNormAct,
    BatchNormAct2d,
    EvoNormBatch2d,
    EvoNormSample2d,
    ClassifierHead,
    DropPath,
    AvgPool2dSame,
    create_pool2d,
    StdConv2d,
    create_conv2d,
)


def _cfg(url="", **kwargs):
    """Return a default pretrained-config dict for a ResNetV2 variant.

    The defaults describe a 1000-class, 224x224 ImageNet model using Inception
    mean/std normalization. Any keyword in ``kwargs`` overrides the matching
    default or adds a new entry.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation="bilinear",
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv="stem.conv",
        classifier="head.fc",
    )
    # Overrides keep the original key position; new keys are appended.
    config.update(kwargs)
    return config


# Per-variant pretrained configs, keyed by the registered model name.
# `_create_resnetv2` looks up `default_cfgs[variant]`, so every registered model needs an entry.
# Unlisted fields fall back to the `_cfg` defaults (1000 classes, 224x224, Inception mean/std).
default_cfgs = {
    # pretrained on imagenet21k, finetuned on imagenet1k
    "resnetv2_50x1_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz",
        input_size=(3, 448, 448),
        pool_size=(14, 14),
        crop_pct=1.0,
    ),
    "resnetv2_50x3_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz",
        input_size=(3, 448, 448),
        pool_size=(14, 14),
        crop_pct=1.0,
    ),
    "resnetv2_101x1_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz",
        input_size=(3, 448, 448),
        pool_size=(14, 14),
        crop_pct=1.0,
    ),
    "resnetv2_101x3_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz",
        input_size=(3, 448, 448),
        pool_size=(14, 14),
        crop_pct=1.0,
    ),
    "resnetv2_152x2_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz",
        input_size=(3, 448, 448),
        pool_size=(14, 14),
        crop_pct=1.0,
    ),
    "resnetv2_152x4_bitm": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz",
        input_size=(3, 480, 480),
        pool_size=(15, 15),
        crop_pct=1.0,
    ),  # NOTE(review): the only fine-tuned BiT-M entry at 480x480 here; confirm this is intended
    # trained on imagenet-21k (21843 classes), no fine-tuning
    "resnetv2_50x1_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz",
        num_classes=21843,
    ),
    "resnetv2_50x3_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz",
        num_classes=21843,
    ),
    "resnetv2_101x1_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz",
        num_classes=21843,
    ),
    "resnetv2_101x3_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz",
        num_classes=21843,
    ),
    "resnetv2_152x2_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz",
        num_classes=21843,
    ),
    "resnetv2_152x4_bitm_in21k": _cfg(
        url="https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz",
        num_classes=21843,
    ),
    # Knowledge-distillation student / teacher weights (https://arxiv.org/abs/2106.05237)
    "resnetv2_50x1_bit_distilled": _cfg(
        url="https://storage.googleapis.com/bit_models/distill/R50x1_224.npz",
        interpolation="bicubic",
    ),
    "resnetv2_152x2_bit_teacher": _cfg(
        url="https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz",
        interpolation="bicubic",
    ),
    "resnetv2_152x2_bit_teacher_384": _cfg(
        url="https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz",
        input_size=(3, 384, 384),
        pool_size=(12, 12),
        crop_pct=1.0,
        interpolation="bicubic",
    ),
    # BatchNorm ResNetV2 variants: no pretrained URL configured (url defaults to "").
    # Deep/tiered-stem ('d'/'t') variants start with `stem.conv1` instead of `stem.conv`.
    "resnetv2_50": _cfg(interpolation="bicubic"),
    "resnetv2_50d": _cfg(interpolation="bicubic", first_conv="stem.conv1"),
    "resnetv2_50t": _cfg(interpolation="bicubic", first_conv="stem.conv1"),
    "resnetv2_101": _cfg(interpolation="bicubic"),
    "resnetv2_101d": _cfg(interpolation="bicubic", first_conv="stem.conv1"),
    "resnetv2_152": _cfg(interpolation="bicubic"),
    "resnetv2_152d": _cfg(interpolation="bicubic", first_conv="stem.conv1"),
}


def make_div(v, divisor=8):
    """Round ``v`` to the nearest multiple of ``divisor``, never below ``divisor``.

    If the rounded value falls more than 10% below ``v``, it is bumped up by one
    extra ``divisor`` step so channel counts don't shrink noticeably.
    """
    nearest = int(v + divisor / 2) // divisor * divisor
    result = max(divisor, nearest)
    return result + divisor if result < 0.9 * v else result


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.

    Residual path: norm1 -> conv1 (1x1) -> norm2 -> conv2 (3x3, strided, grouped)
    -> norm3 -> conv3 (1x1) -> drop_path. When a projection shortcut exists it is
    fed the pre-activated input (the ``norm1`` output); otherwise the raw input is
    added back unchanged.

    ``act_layer`` is accepted for signature parity with ``Bottleneck`` but is not
    used here; activations presumably come from ``norm_layer`` (e.g. ``GroupNormAct``).
    """

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
    ):
        super().__init__()
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        first_dilation = first_dilation or dilation
        out_chs = out_chs or in_chs
        bottleneck_chs = make_div(out_chs * bottle_ratio)

        # Attribute assignment order below fixes submodule registration order
        # (state_dict keys and init traversal), so keep it stable.
        if proj_layer is None:
            self.downsample = None
        else:
            self.downsample = proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                first_dilation=first_dilation,
                preact=True,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
            )

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, bottleneck_chs, 1)
        self.norm2 = norm_layer(bottleneck_chs)
        self.conv2 = conv_layer(
            bottleneck_chs,
            bottleneck_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
        )
        self.norm3 = norm_layer(bottleneck_chs)
        self.conv3 = conv_layer(bottleneck_chs, out_chs, 1)
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def zero_init_last(self):
        """Zero the final 1x1 conv so the residual branch starts as a no-op."""
        nn.init.zeros_(self.conv3.weight)

    def forward(self, x):
        preact = self.norm1(x)

        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(preact)

        residual = self.conv1(preact)
        residual = self.conv2(self.norm2(residual))
        residual = self.conv3(self.norm3(residual))
        return self.drop_path(residual) + shortcut


class Bottleneck(nn.Module):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.

    Residual path: conv1 (1x1) -> norm1 -> conv2 (3x3, strided, grouped) -> norm2
    -> conv3 (1x1) -> norm3 (no activation) -> drop_path. The shortcut (optionally
    projected) is added and ``act3`` is applied to the sum.
    """

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
    ):
        super().__init__()
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        first_dilation = first_dilation or dilation
        out_chs = out_chs or in_chs
        bottleneck_chs = make_div(out_chs * bottle_ratio)

        # Keep assignment order: it defines submodule registration order.
        if proj_layer is None:
            self.downsample = None
        else:
            self.downsample = proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                preact=False,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
            )

        self.conv1 = conv_layer(in_chs, bottleneck_chs, 1)
        self.norm1 = norm_layer(bottleneck_chs)
        self.conv2 = conv_layer(
            bottleneck_chs,
            bottleneck_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
        )
        self.norm2 = norm_layer(bottleneck_chs)
        self.conv3 = conv_layer(bottleneck_chs, out_chs, 1)
        # No activation here: it is applied after the shortcut addition (act3).
        self.norm3 = norm_layer(out_chs, apply_act=False)
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.act3 = act_layer(inplace=True)

    def zero_init_last(self):
        """Zero the last norm's scale so the residual branch starts as a no-op."""
        nn.init.zeros_(self.norm3.weight)

    def forward(self, x):
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        out = self.norm1(self.conv1(x))
        out = self.norm2(self.conv2(out))
        out = self.norm3(self.conv3(out))
        out = self.drop_path(out)
        return self.act3(out + shortcut)


class DownsampleConv(nn.Module):
    """Projection shortcut: strided 1x1 conv, followed by a norm for non-pre-act blocks.

    ``dilation`` and ``first_dilation`` are accepted for interface parity with
    ``DownsampleAvg`` and are unused.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
    ):
        super().__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        if preact:
            # Pre-act blocks normalize before the shortcut, so nothing to do here.
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        projected = self.conv(x)
        return self.norm(projected)


class DownsampleAvg(nn.Module):
    """AvgPool + 1x1 conv projection shortcut, as in 'D' ResNet variants."""

    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
    ):
        """AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
        super().__init__()
        if stride <= 1 and dilation <= 1:
            self.pool = nn.Identity()
        else:
            # With dilation the spatial stride is dropped (pool stride 1); that case
            # needs the 'same'-padded pool to preserve resolution.
            pool_stride = stride if dilation == 1 else 1
            if pool_stride == 1 and dilation > 1:
                pool_cls = AvgPool2dSame
            else:
                pool_cls = nn.AvgPool2d
            self.pool = pool_cls(2, pool_stride, ceil_mode=True, count_include_pad=False)
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        if preact:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        pooled = self.pool(x)
        return self.norm(self.conv(pooled))


class ResNetStage(nn.Module):
    """ResNet Stage.

    A sequence of ``depth`` blocks built by ``block_fn``. Only the first block
    receives the stage ``stride``, a projection shortcut, and the reduced
    ``first_dilation``. Every later block uses stride 1, no projection, and
    ``first_dilation == dilation``.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride,
        dilation,
        depth,
        bottle_ratio=0.25,
        groups=1,
        avg_down=False,
        block_dpr=None,
        block_fn=PreActBottleneck,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        **block_kwargs,
    ):
        super().__init__()
        shared_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
        lead_dilation = 1 if dilation in (1, 2) else 2
        lead_proj = DownsampleAvg if avg_down else DownsampleConv

        self.blocks = nn.Sequential()
        block_in_chs = in_chs
        for idx in range(depth):
            is_lead = idx == 0
            block = block_fn(
                block_in_chs,
                out_chs,
                stride=stride if is_lead else 1,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                first_dilation=lead_dilation if is_lead else dilation,
                proj_layer=lead_proj if is_lead else None,
                drop_path_rate=block_dpr[idx] if block_dpr else 0.0,
                **shared_kwargs,
                **block_kwargs,
            )
            self.blocks.add_module(str(idx), block)
            block_in_chs = out_chs

    def forward(self, x):
        return self.blocks(x)


def is_stem_deep(stem_type):
    """Return True if ``stem_type`` selects the 3-conv ('deep' or 'tiered') stem."""
    return "deep" in stem_type or "tiered" in stem_type


def create_resnetv2_stem(
    in_chs,
    out_chs=64,
    stem_type="",
    preact=True,
    conv_layer=StdConv2d,
    norm_layer=partial(GroupNormAct, num_groups=32),
):
    """Build the ResNetV2 stem as an ``nn.Sequential`` of named layers.

    Conv part: a single 7x7 stride-2 conv, or for 'deep'/'tiered' types three
    3x3 convs (first one stride 2). For non-pre-act models, a norm follows the
    final conv.
    Pool part (stride 2): 'fixed' uses a constant pad of 1 followed by an unpadded
    max pool, 'same' uses input-size based SAME padding, and anything else uses
    symmetric padding of 1.
    """
    layers = OrderedDict()
    assert stem_type in (
        "",
        "fixed",
        "same",
        "deep",
        "deep_fixed",
        "deep_same",
        "tiered",
    )

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if not is_stem_deep(stem_type):
        # The usual 7x7 stem conv
        layers["conv"] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
        if not preact:
            layers["norm"] = norm_layer(out_chs)
    else:
        # 3-deep 3x3 conv stack as in ResNet V1D ('deep') / 'T' ('tiered') models
        half_chs = out_chs // 2
        first_chs = 3 * out_chs // 8 if "tiered" in stem_type else half_chs
        layers["conv1"] = conv_layer(in_chs, first_chs, kernel_size=3, stride=2)
        layers["norm1"] = norm_layer(first_chs)
        layers["conv2"] = conv_layer(first_chs, half_chs, kernel_size=3, stride=1)
        layers["norm2"] = norm_layer(half_chs)
        layers["conv3"] = conv_layer(half_chs, out_chs, kernel_size=3, stride=1)
        if not preact:
            layers["norm3"] = norm_layer(out_chs)

    if "fixed" in stem_type:
        # 'fixed' SAME padding approximation used by the BiT models
        layers["pad"] = nn.ConstantPad2d(1, 0.0)
        layers["pool"] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    elif "same" in stem_type:
        # full, input-size based 'SAME' padding, used in the ViT Hybrid model
        layers["pool"] = create_pool2d("max", kernel_size=3, stride=2, padding="same")
    else:
        # standard PyTorch symmetric padding
        layers["pool"] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    return nn.Sequential(layers)


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model.

    The model consists of:

    - a stem built by ``create_resnetv2_stem``;
    - one ``ResNetStage`` per entry of ``layers``/``channels``;
    - a final ``norm``;
    - a ``ClassifierHead``.

    With ``preact=True`` the stages use ``PreActBottleneck`` and ``norm`` is a real
    norm layer. Otherwise the stages use ``Bottleneck`` and ``norm`` is an identity.

    Args:
        layers: number of blocks in each stage.
        channels: base output channels per stage (scaled by ``width_factor``).
        num_classes: number of classifier outputs.
        in_chans: number of input channels.
        global_pool: pooling type passed to ``ClassifierHead``.
        output_stride: once the running stride reaches this value, further stage
            strides are converted into dilation.
        width_factor: multiplier applied to ``stem_chs`` and ``channels``.
        stem_chs: base number of stem output channels.
        stem_type: stem variant accepted by ``create_resnetv2_stem``.
        avg_down: use ``DownsampleAvg`` instead of ``DownsampleConv`` shortcuts.
        preact: pre-activation blocks (True) or non-pre-activation blocks (False).
        act_layer: activation factory forwarded to the stages.
        conv_layer, norm_layer: layer factories forwarded to the stem and stages.
        drop_rate: passed to ``ClassifierHead`` as its ``drop_rate``.
        drop_path_rate: stochastic-depth rate of the last block; per-block rates
            ramp linearly from 0 across all blocks.
        zero_init_last: call each block's ``zero_init_last`` during weight init.
    """

    def __init__(
        self,
        layers,
        channels=(256, 512, 1024, 2048),
        num_classes=1000,
        in_chans=3,
        global_pool="avg",
        output_stride=32,
        width_factor=1,
        stem_chs=64,
        stem_type="",
        avg_down=False,
        preact=True,
        act_layer=nn.ReLU,
        conv_layer=StdConv2d,
        norm_layer=partial(GroupNormAct, num_groups=32),
        drop_rate=0.0,
        drop_path_rate=0.0,
        zero_init_last=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_resnetv2_stem(
            in_chans,
            stem_chs,
            stem_type,
            preact,
            conv_layer=conv_layer,
            norm_layer=norm_layer,
        )
        # Pre-act stems end in a conv. Non-pre-act stems end in a norm, which the
        # deep/tiered stem names "norm3" and the 7x7 stem names "norm".
        deep_stem = is_stem_deep(stem_type)
        if preact:
            stem_feat = "stem.conv3" if deep_stem else "stem.conv"
        else:
            stem_feat = "stem.norm3" if deep_stem else "stem.norm"
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

        prev_chs = stem_chs
        curr_stride = 4  # stride-2 stem conv + stride-2 stem pool
        dilation = 1
        # Linearly increasing drop-path rate over all blocks, split per stage.
        block_dprs = [
            x.tolist()
            for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)
        ]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                # Target output stride reached: trade further striding for dilation.
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                depth=d,
                avg_down=avg_down,
                act_layer=act_layer,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                block_dpr=bdpr,
                block_fn=block_fn,
            )
            prev_chs = out_chs
            curr_stride *= stride
            self.feature_info += [
                dict(
                    num_chs=prev_chs,
                    reduction=curr_stride,
                    module=f"stages.{stage_idx}",
                )
            ]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        # Pre-act blocks end in a bare conv, so a final norm is needed before the head.
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )

        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last=True):
        """Apply ``_init_weights`` to every named submodule (via ``named_apply``)."""
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix="resnet/"):
        """Load original BiT ``.npz`` weights; see ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        """Return the final classifier layer (``head.fc``)."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool="avg"):
        """Replace the head with a freshly constructed ``ClassifierHead``."""
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )

    def forward_features(self, x):
        """Run stem, stages and final norm; returns the unpooled feature map."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_weights(module: nn.Module, name: str = "", zero_init_last=True):
    if isinstance(module, nn.Linear) or (
        "head.fc" in name and isinstance(module, nn.Conv2d)
    ):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif zero_init_last and hasattr(module, "zero_init_last"):
        module.zero_init_last()


@torch.no_grad()
def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = "resnet/"):
    """Copy weights from an original BiT ``.npz`` checkpoint into ``model`` in place.

    Args:
        model: a ``ResNetV2`` with a 7x7 stem. It must expose ``stem.conv``, a final
            ``norm``, ``head``, and ``stages[i].blocks[j]`` with ``conv1..3`` and
            ``norm1..3``.
        checkpoint_path: path to the ``.npz`` archive.
        prefix: key prefix of the arrays inside the archive.

    The stem conv is adapted to the model's input channel count. The classifier
    head is only copied when ``head.fc`` is a conv whose output width matches the
    checkpoint's. Otherwise (e.g. a different ``num_classes``) the existing head
    is kept. For projection shortcuts only ``downsample.conv`` is loaded.
    """
    import numpy as np

    def t2p(conv_weights):
        """Convert a numpy array to a tensor, transposing 4-D HWIO kernels to OIHW."""
        if conv_weights.ndim == 4:
            conv_weights = conv_weights.transpose([3, 2, 0, 1])
        return torch.from_numpy(conv_weights)

    # An .npz load returns an NpzFile that keeps the archive open; close it deterministically.
    with np.load(checkpoint_path) as weights:
        stem_conv_w = adapt_input_conv(
            model.stem.conv.weight.shape[1],
            t2p(weights[f"{prefix}root_block/standardized_conv2d/kernel"]),
        )
        model.stem.conv.weight.copy_(stem_conv_w)
        model.norm.weight.copy_(t2p(weights[f"{prefix}group_norm/gamma"]))
        model.norm.bias.copy_(t2p(weights[f"{prefix}group_norm/beta"]))

        head_fc = getattr(model.head, "fc", None)
        if isinstance(head_fc, nn.Conv2d):
            # Only read the head kernel when there is a conv head to load into.
            head_w = weights[f"{prefix}head/conv2d/kernel"]
            if head_fc.weight.shape[0] == head_w.shape[-1]:
                head_fc.weight.copy_(t2p(head_w))
                head_fc.bias.copy_(t2p(weights[f"{prefix}head/conv2d/bias"]))

        cname = "standardized_conv2d"
        for i, stage in enumerate(model.stages.children()):
            for j, block in enumerate(stage.blocks.children()):
                block_prefix = f"{prefix}block{i + 1}/unit{j + 1:02d}/"
                # BiT names the three bottleneck units a / b / c.
                units = (
                    ("a", block.conv1, block.norm1),
                    ("b", block.conv2, block.norm2),
                    ("c", block.conv3, block.norm3),
                )
                for unit, conv, norm in units:
                    conv.weight.copy_(t2p(weights[f"{block_prefix}{unit}/{cname}/kernel"]))
                    norm.weight.copy_(t2p(weights[f"{block_prefix}{unit}/group_norm/gamma"]))
                    norm.bias.copy_(t2p(weights[f"{block_prefix}{unit}/group_norm/beta"]))
                if block.downsample is not None:
                    w = weights[f"{block_prefix}a/proj/{cname}/kernel"]
                    block.downsample.conv.weight.copy_(t2p(w))


def _create_resnetv2(variant, pretrained=False, **kwargs):
    """Build a ``ResNetV2`` for ``variant`` using its entry in ``default_cfgs``.

    Pretrained weights go through the model's custom loader
    (``pretrained_custom_load=True``). Feature extraction flattens the nested
    ``nn.Sequential`` modules.
    """
    return build_model_with_cfg(
        ResNetV2,
        variant,
        pretrained,
        default_cfg=default_cfgs[variant],
        feature_cfg={"flatten_sequential": True},
        pretrained_custom_load=True,
        **kwargs,
    )


def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
    """Build a BiT-style ResNetV2.

    BiT models use the 'fixed' stem padding and ``StdConv2d`` with ``eps=1e-8``.
    """
    bit_conv = partial(StdConv2d, eps=1e-8)
    return _create_resnetv2(
        variant, pretrained=pretrained, stem_type="fixed", conv_layer=bit_conv, **kwargs
    )


@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-50x1, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, **kwargs)
    return _create_resnetv2_bit("resnetv2_50x1_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-50x3, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 6, 3], width_factor=3, **kwargs)
    return _create_resnetv2_bit("resnetv2_50x3_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-101x1, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1, **kwargs)
    return _create_resnetv2_bit("resnetv2_101x1_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-101x3, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 4, 23, 3], width_factor=3, **kwargs)
    return _create_resnetv2_bit("resnetv2_101x3_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-152x2, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, **kwargs)
    return _create_resnetv2_bit("resnetv2_152x2_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    """BiT-M ResNetV2-152x4, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k."""
    model_args = dict(layers=[3, 8, 36, 3], width_factor=4, **kwargs)
    return _create_resnetv2_bit("resnetv2_152x4_bitm", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-50x1 trained on ImageNet-21k (21843 classes unless overridden)."""
    model_args = dict(
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 4, 6, 3],
        width_factor=1,
        **kwargs,
    )
    return _create_resnetv2_bit("resnetv2_50x1_bitm_in21k", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-50x3 trained on ImageNet-21k (21843 classes unless overridden)."""
    model_args = dict(
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 4, 6, 3],
        width_factor=3,
        **kwargs,
    )
    return _create_resnetv2_bit("resnetv2_50x3_bitm_in21k", pretrained=pretrained, **model_args)


@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-101x1 trained on ImageNet-21k (21843 classes unless overridden).

    Built through ``_create_resnetv2_bit`` like every other BiT variant, so it gets
    the 'fixed' stem and ``StdConv2d(eps=1e-8)``. The generic ``_create_resnetv2``
    would instead fall back to the ``ResNetV2`` defaults (``stem_type=""`` and the
    default ``StdConv2d``), which differ from the other BiT models.
    """
    return _create_resnetv2_bit(
        "resnetv2_101x1_bitm_in21k",
        pretrained=pretrained,
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 4, 23, 3],
        width_factor=1,
        **kwargs,
    )


@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-101x3 trained on ImageNet-21k (21843 classes unless overridden)."""
    model_args = dict(
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 4, 23, 3],
        width_factor=3,
        **kwargs,
    )
    return _create_resnetv2_bit("resnetv2_101x3_bitm_in21k", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-152x2 trained on ImageNet-21k (21843 classes unless overridden)."""
    model_args = dict(
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 8, 36, 3],
        width_factor=2,
        **kwargs,
    )
    return _create_resnetv2_bit("resnetv2_152x2_bitm_in21k", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    """BiT-M ResNetV2-152x4 trained on ImageNet-21k (21843 classes unless overridden)."""
    model_args = dict(
        num_classes=kwargs.pop("num_classes", 21843),
        layers=[3, 8, 36, 3],
        width_factor=4,
        **kwargs,
    )
    return _create_resnetv2_bit("resnetv2_152x4_bitm_in21k", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
    """ResNetV2-50x1-BiT Distilled
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    model_args = dict(layers=[3, 4, 6, 3], width_factor=1, **kwargs)
    return _create_resnetv2_bit("resnetv2_50x1_bit_distilled", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
    """ResNetV2-152x2-BiT Teacher
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    model_args = dict(layers=[3, 8, 36, 3], width_factor=2, **kwargs)
    return _create_resnetv2_bit("resnetv2_152x2_bit_teacher", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
    """ResNetV2-152x2-BiT Teacher @ 384x384
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    return _create_resnetv2_bit(
        "resnetv2_152x2_bit_teacher_384",
        pretrained=pretrained,
        layers=[3, 8, 36, 3],
        width_factor=2,
        **kwargs,
    )


@register_model
def resnetv2_50(pretrained=False, **kwargs):
    """ResNetV2-50 built with ``create_conv2d`` convs and ``BatchNormAct2d`` norms."""
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_50", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    """ResNetV2-50d: BatchNorm, deep 3x3 stem and avg-pool projection shortcuts."""
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type="deep",
        avg_down=True,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_50d", pretrained=pretrained, **model_args)


@register_model
def resnetv2_50t(pretrained=False, **kwargs):
    """ResNetV2-50t: BatchNorm, tiered 3x3 stem and avg-pool projection shortcuts."""
    model_args = dict(
        layers=[3, 4, 6, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type="tiered",
        avg_down=True,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_50t", pretrained=pretrained, **model_args)


@register_model
def resnetv2_101(pretrained=False, **kwargs):
    """ResNetV2-101 built with ``create_conv2d`` convs and ``BatchNormAct2d`` norms."""
    model_args = dict(
        layers=[3, 4, 23, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_101", pretrained=pretrained, **model_args)


@register_model
def resnetv2_101d(pretrained=False, **kwargs):
    """ResNetV2-101d: BatchNorm, deep 3x3 stem and avg-pool projection shortcuts."""
    model_args = dict(
        layers=[3, 4, 23, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type="deep",
        avg_down=True,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_101d", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152(pretrained=False, **kwargs):
    """ResNetV2-152 built with ``create_conv2d`` convs and ``BatchNormAct2d`` norms."""
    model_args = dict(
        layers=[3, 8, 36, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_152", pretrained=pretrained, **model_args)


@register_model
def resnetv2_152d(pretrained=False, **kwargs):
    """ResNetV2-152d: BatchNorm, deep 3x3 stem and avg-pool projection shortcuts."""
    model_args = dict(
        layers=[3, 8, 36, 3],
        conv_layer=create_conv2d,
        norm_layer=BatchNormAct2d,
        stem_type="deep",
        avg_down=True,
        **kwargs,
    )
    return _create_resnetv2("resnetv2_152d", pretrained=pretrained, **model_args)


# @register_model
# def resnetv2_50ebd(pretrained=False, **kwargs):
#     # FIXME for testing w/ TPU + PyTorch XLA
#     return _create_resnetv2(
#         'resnetv2_50d', pretrained=pretrained,
#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
#         stem_type='deep', avg_down=True, **kwargs)
#
#
# @register_model
# def resnetv2_50esd(pretrained=False, **kwargs):
#     # FIXME for testing w/ TPU + PyTorch XLA
#     return _create_resnetv2(
#         'resnetv2_50d', pretrained=pretrained,
#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
#         stem_type='deep', avg_down=True, **kwargs)
